# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import threading
import time
from collections import defaultdict
from functools import partial
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer

from ..core._imperative_rt.utils import create_mm_server
from ..utils.future import Future


class Methods:
    r"""Distributed Server Method.
    Used to exchange information between distributed nodes.

    Every public method is exposed over XML-RPC once an instance is
    registered on the server, and may be called from several request
    threads at once; ``self.lock`` guards all dictionaries below.

    Args:
        mm_server_port: multiple machine rpc server port.
    """

    def __init__(self, mm_server_port):
        # Protects lookups/insertions/deletions on every dict of this object.
        # Futures and events are waited on *outside* the lock so that a
        # blocked caller never stalls unrelated requests.
        self.lock = threading.Lock()
        self.mm_server_port = mm_server_port
        self.dict_is_grad = defaultdict(partial(Future, True))
        self.dict_remote_tracer = defaultdict(partial(Future, True))
        self.dict_pack_list = defaultdict(partial(Future, False))
        self.dict_barrier_counter = defaultdict(int)
        self.dict_barrier_event = defaultdict(threading.Event)
        self.user_dict = defaultdict(partial(Future, False))
        # key -> [future holding the broadcast value, remaining caller count]
        self.bcast_dict = {}

    def connect(self):
        r"""Method for checking connection success."""
        return True

    def get_mm_server_port(self):
        r"""Get multiple machine rpc server port."""
        return self.mm_server_port

    def set_is_grad(self, key, is_grad):
        r"""Mark send/recv need gradients by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op need grad.

        Returns:
            Always ``True``.
        """
        with self.lock:
            future = self.dict_is_grad[key]
        future.set(is_grad)
        return True

    def check_is_grad(self, key):
        r"""Check whether send/recv need gradients.

        The entry for ``key`` is removed once its value has been read.

        Args:
            key: key to match send/recv op.

        Returns:
            The value obtained from the future stored under ``key``.
        """
        with self.lock:
            future = self.dict_is_grad[key]
        ret = future.get()
        with self.lock:
            del self.dict_is_grad[key]
        return ret

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.

        Returns:
            Always ``True``.
        """
        with self.lock:
            future = self.dict_remote_tracer[key]
        future.set(tracer_set)
        return True

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        The entry for ``key`` is removed once its value has been read.

        Args:
            key: key to match send/recv op.

        Returns:
            The value obtained from the future stored under ``key``.
        """
        with self.lock:
            future = self.dict_remote_tracer[key]
        ret = future.get()
        with self.lock:
            del self.dict_remote_tracer[key]
        return ret

    def group_barrier(self, key, size):
        r"""A barrier wait for all group member.

        Args:
            key: group key to match each other.
            size: group size.

        Returns:
            Always ``True``, once ``size`` callers have arrived for ``key``.
        """
        with self.lock:
            self.dict_barrier_counter[key] += 1
            is_last = self.dict_barrier_counter[key] == size
            event = self.dict_barrier_event[key]
            if is_last:
                # Remove the bookkeeping while still holding the lock, so a
                # later call with the same key starts from a fresh counter
                # and event instead of racing with this cleanup.
                del self.dict_barrier_counter[key]
                del self.dict_barrier_event[key]
        if is_last:
            event.set()
        else:
            event.wait()
        return True

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes."""
        with self.lock:
            future = self.user_dict[key]
        future.set(val)
        return True

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes."""
        with self.lock:
            future = self.user_dict[key]
        return future.get()

    def bcast_val(self, val, key, size):
        r"""Broadcast a value among ``size`` callers that share ``key``.

        Args:
            val: the value to publish, or ``None`` to receive it.
            key: key shared by all participants of this broadcast.
            size: number of calls expected for ``key``; the entry is
                removed after that many calls have completed.

        Returns:
            ``None`` for the caller that published a value, the published
            value for every caller that passed ``None``.
        """
        with self.lock:
            if key not in self.bcast_dict:
                self.bcast_dict[key] = [Future(False), size]
            arr = self.bcast_dict[key]
        if val is not None:
            arr[0].set(val)
            val = None
        else:
            val = arr[0].get()
        with self.lock:
            # Count this caller as done; the last one frees the entry.
            cnt = arr[1] - 1
            arr[1] = cnt
            if cnt == 0:
                del self.bcast_dict[key]
        return val

    def _del(self, key):
        r"""Remove ``key`` from the user dictionary."""
        with self.lock:
            del self.user_dict[key]

    # thread safe function
    def user_pop(self, key):
        r"""Get the user defined value for ``key`` and delete its entry."""
        ret = self.user_get(key)
        self._del(key)
        return ret


class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    r"""``SimpleXMLRPCServer`` combined with ``ThreadingMixIn``."""


def _start_server(py_server_port, queue):
    r"""Start python distributed server and multiple machine server.

    Args:
        py_server_port: port passed to the python XML-RPC server.
        queue: receives a ``(py_server_port, mm_server_port)`` tuple once both
            servers are created, or the exception instance if startup or
            serving fails.
    """
    try:
        mm_port = create_mm_server("0.0.0.0", 0)
        rpc_server = ThreadXMLRPCServer(
            ("0.0.0.0", py_server_port), logRequests=False, allow_none=True
        )
        rpc_server.register_instance(Methods(mm_port))
        # Report the port read back from the server's bound address.
        _, bound_port = rpc_server.server_address
        queue.put((bound_port, mm_port))
        rpc_server.serve_forever()
    except Exception as err:
        queue.put(err)


class Server:
    r"""Distributed Server for distributed training.
    Should be running at master node.

    The server runs in a daemon child process; construction blocks until the
    child reports its ports (or an exception) through a queue.

    Args:
        port: python server port.

    Raises:
        Exception: whatever exception the child process reported while
            starting the servers.
    """

    def __init__(self, port=0):
        q = mp.Queue()
        proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
        proc.start()
        # Only publish the process once it has really been started, so that
        # __del__ never tries to terminate a process that was never launched.
        self.proc = proc
        ret = q.get()
        if isinstance(ret, Exception):
            raise ret
        self.py_server_port, self.mm_server_port = ret

    def __del__(self):
        # __init__ may have failed before ``proc`` was assigned; in that case
        # there is nothing to terminate.
        proc = getattr(self, "proc", None)
        if proc is not None:
            proc.terminate()


class Client:
    r"""Distributed Client for distributed training.

    Args:
        master_ip: ip address of master node.
        port: port of server at master node.
    """

    def __init__(self, master_ip, port):
        self.master_ip = master_ip
        self.port = port
        self.connect()
        # Per-key sequence number used to build a unique key for every
        # bcast_val call made through this client.
        self.bcast_dict = defaultdict(int)

    def connect(self):
        r"""Check connection success.

        Creates the XML-RPC proxy and retries once per second until the
        server answers ``connect``.
        """
        while True:
            try:
                self.proxy = ServerProxy(
                    "http://{}:{}".format(self.master_ip, self.port), allow_none=True
                )
                if self.proxy.connect():
                    break
            except Exception:
                # Retry on failure, but let KeyboardInterrupt / SystemExit
                # propagate so the loop can be interrupted.
                time.sleep(1)

    def get_mm_server_port(self):
        r"""Get multiple machine server port.

        Retries every 0.5 seconds until the call succeeds.
        """
        while True:
            try:
                return self.proxy.get_mm_server_port()
            except Exception:
                time.sleep(0.5)

    def set_is_grad(self, key, is_grad):
        r"""Mark send/recv need gradients by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op need grad.
        """
        self.proxy.set_is_grad(key, is_grad)

    def check_is_grad(self, key):
        r"""Check whether send/recv need gradients.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_is_grad(key)

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.
        """
        self.proxy.set_remote_tracer(key, tracer_set)

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_remote_tracer(key)

    def group_barrier(self, key, size):
        r"""A barrier wait for all group member.

        Retries every 0.5 seconds until the call succeeds.

        Args:
            key: group key to match each other.
            size: group size.
        """
        # FIXME: group_barrier is not idempotent
        while True:
            try:
                self.proxy.group_barrier(key, size)
                return
            except Exception:
                time.sleep(0.5)

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes."""
        return self.proxy.user_set(key, val)

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes."""
        return self.proxy.user_get(key)

    def user_pop(self, key):
        r"""Get user defined key-value pairs and delete the resources when the get is done"""
        return self.proxy.user_pop(key)

    def bcast_val(self, val, key, size):
        r"""Broadcast a value among ``size`` callers through the server.

        Each call appends a per-key running index to ``key`` before it is
        sent to the server.

        Args:
            val: value to publish, or ``None`` to receive the published value.
            key: base key shared by the participants.
            size: number of participants.

        Returns:
            Whatever the server's ``bcast_val`` returns for this caller.
        """
        idx = self.bcast_dict[key] + 1
        self.bcast_dict[key] = idx
        key = key + "_bcast_" + str(idx)
        return self.proxy.bcast_val(val, key, size)


def main(port=0, verbose=True):
    r"""Run the distributed server in the current process and serve forever.

    Args:
        port: python server port; the port read back from the server's
            address is printed before serving starts.
        verbose: passed to the XML-RPC server as ``logRequests``.
    """
    mm_server_port = create_mm_server("0.0.0.0", 0)
    # allow_none=True matches _start_server and the Client proxy:
    # Methods.bcast_val returns None to the publishing caller, which the
    # XML-RPC server cannot marshal unless None is allowed.
    server = ThreadXMLRPCServer(
        ("0.0.0.0", port), logRequests=verbose, allow_none=True
    )
    server.register_instance(Methods(mm_server_port))
    _, port = server.server_address
    print("serving on port", port)
    server.serve_forever()


if __name__ == "__main__":
    import argparse

    def _parse_bool(text):
        # argparse's ``type=bool`` turns every non-empty string (including
        # "False" and "0") into True, so the flag value is parsed explicitly.
        return text.lower() not in ("", "0", "false", "no", "off")

    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--port", type=int, default=0)
    ap.add_argument("-v", "--verbose", type=_parse_bool, default=True)
    args = ap.parse_args()
    main(port=args.port, verbose=args.verbose)
